-
Notifications
You must be signed in to change notification settings - Fork 764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[#432] Add Groq Provider - chat completions #609
Conversation
elif finish_reason == "length": | ||
return StopReason.end_of_message | ||
elif finish_reason == "tool_calls": | ||
raise NotImplementedError("tool_calls is not supported yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Users won't be able to hit this error yet since they can't pass tools as a parameter
"remote::groq", | ||
): | ||
pytest.skip(provider.__provider_spec__.provider_type + " doesn't support tool calling yet") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As per comment above: https://github.com/meta-llama/llama-stack/pull/609/files#r1881443869
Will remove after I implement
warnings.warn("repetition_penalty is not supported") | ||
|
||
if request.tools: | ||
warnings.warn("tools are not supported yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’m planning to handle tool calls in a separate PR since there are edge cases I want to cover properly.
But lmk if you want me to include it within this PR
8fa0bae
to
98e3563
Compare
ec8b47b
to
3f2498e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good!
@json_schema_type | ||
class GroqConfig(BaseModel): | ||
api_key: Optional[str] = Field( | ||
default=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - the groq library will read GROQ_API_KEY env (https://github.com/groq/groq-python/blob/main/src/groq/_client.py#L86), consider adding a comment here so people in the LS codebase know this expectation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not rely on environment variables for code that we expect to run in llama-stack server. We would want to take in the api key as a config variable in run.yaml when we spin up the server
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would want to take in the api key as a config variable in run.yaml when we spin up the server
@raghotham, I believe that's the behaviour at the moment. This is how fireworks and together define their configs:
api_key: Optional[str] = Field( | |
default=None, | |
description="The Fireworks.ai API Key", | |
) |
llama-stack/llama_stack/providers/remote/inference/together/config.py
Lines 19 to 22 in 6765fd7
api_key: Optional[str] = Field( | |
default=None, | |
description="The Together AI API Key", | |
) |
And it's in the run.yaml that you define the environment variable:
llama-stack/llama_stack/templates/together/run.yaml
Lines 18 to 20 in 516e1a3
config: | |
url: https://api.together.xyz/v1 | |
api_key: ${env.TOGETHER_API_KEY} |
wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes @aidando73 this is correct. Note that both Together and Fireworks also support grabbing the api key from headers via the NeedsProviderData
mixin. You can add that if you feel like it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done - added the mixin
Added some client code in the test plan as well to test
]: | ||
|
||
if model_id == "llama-3.2-3b-preview": | ||
warnings.warn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very user friendly +1
logprobs=None, | ||
frequency_penalty=None, | ||
stream=request.stream, | ||
# Groq only supports n=1 at the time of writing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the LS structures only support responses w/ 1 choice, so the fact that groq only supports n=1 is moot. you should even skip passing the param around.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
raise ValueError(f"Invalid finish reason: {finish_reason}") | ||
|
||
|
||
async def convert_chat_completion_response_stream( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in some other PR, we should merge this into the general util module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think this function is too coupled to Groq types to be used as general util function? E.g., this one takes in a ChatCompletionChunk
from groq.types.chat.chat_completion_chunk
2f1522a
to
3587f08
Compare
|
||
if request.logprobs: | ||
# Groq doesn't support logprobs at the time of writing | ||
warnings.warn("logprobs are not supported yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
if request.response_format: | ||
# Groq's JSON mode is beta at the time of writing | ||
warnings.warn("response_format is not supported yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: | ||
# import dynamically so `llama stack build` does not fail due to missing dependencies |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the correct comment should be import dynamically so the import is used only when it is needed
:D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
CoreModelId.llama3_70b_instruct.value, | ||
), | ||
build_model_alias( | ||
"llama-3.3-70b-versatile", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you know what does this suffix indicate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't find anything online. @ricklamers @philass - could you provide any additional context here?
|
||
yield ChatCompletionResponseStreamChunk( | ||
event=ChatCompletionResponseEvent( | ||
event_type=next(event_types), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems unnecessary to refactor the event type generator separately. just maintain an index here or even more simply just repeat a small amount of code to send back a start chunk.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome, thank you very much!
(will merge after a couple small comments are addressed) |
78912e6
to
0f6beb1
Compare
0f6beb1
to
c0757fd
Compare
Ok @ashwinb I've addressed your comments |
Can this PR be merged soon ? |
Thanks for the reminder @cheesecake100201! Sorry was out on vacation and missed this one after coming back. |
# What does this PR do? Contributes towards: #432 RE: #609 I missed this one while refactoring. Fixes: ```python Traceback (most recent call last): File "/Users/aidand/dev/llama-stack/llama_stack/distribution/server/server.py", line 191, in endpoint return await maybe_await(value) File "/Users/aidand/dev/llama-stack/llama_stack/distribution/server/server.py", line 155, in maybe_await return await value File "/Users/aidand/dev/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 101, in async_wrapper result = await method(self, *args, **kwargs) File "/Users/aidand/dev/llama-stack/llama_stack/distribution/routers/routers.py", line 156, in chat_completion return await provider.chat_completion(**params) File "/Users/aidand/dev/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 101, in async_wrapper result = await method(self, *args, **kwargs) File "/Users/aidand/dev/llama-stack/llama_stack/providers/remote/inference/groq/groq.py", line 127, in chat_completion response = self._get_client().chat.completions.create(**request) File "/Users/aidand/dev/llama-stack/llama_stack/providers/remote/inference/groq/groq.py", line 143, in _get_client return Groq(api_key=self.config.api_key) AttributeError: 'GroqInferenceAdapter' object has no attribute 'config'. Did you mean: '_config'? ``` ## Test Plan Environment: ```shell export GROQ_API_KEY=<api-key> # build.yaml and run.yaml files wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml # Create environment if not already conda create --prefix ./envs python=3.10 conda activate ./envs # Build pip install -e . && llama stack build --config ./build.yaml --image-type conda # Activate built environment conda activate llamastack-groq ``` <details> <summary>Manual</summary> ```bash llama stack run ./run.yaml --port 5001 ``` Via this Jupyter notebook: https://github.com/aidando73/llama-stack/blob/9165502582cd7cb178bc1dcf89955b45768ab6c1/hello.ipynb </details> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [x] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
What does this PR do?
Contributes towards issue (#432)
A lot of inspiration taken from @mattf's good work at #355
What this PR does not do
PR Train
Test Plan
Environment
Manual tests
Using this jupyter notebook to test manually: https://github.com/aidando73/llama-stack/blob/2140976d76ee7ef46025c862b26ee87585381d2a/hello.ipynb
Use this code to test passing in the api key from provider_data
Integration
pytest llama_stack/providers/tests/inference/test_text_inference.py -v -k groq
(run in same environment)
Unit tests
pytest llama_stack/providers/tests/inference/groq/ -v
Before submitting
Pull Request section?